#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2010 Marcel Martin <marcel.martin@tu-dortmund.de>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

"""%prog [options] <FASTA/FASTQ FILE> [<QUALITY FILE>]

Reads a FASTA or FASTQ file, finds and removes adapters,
and writes the changed sequence to standard output.
When finished, statistics are printed to standard error.

If two file names are given, they are assumed to be
.csfasta and .qual files as produced by the SOLiD sequencer.
(You still need to provide the -c option to correctly deal
with color space.)

If the name of any input or output file ends with '.gz', it is
assumed to be gzip-compressed.

If you want to search for the reverse complement of an adapter, you must
provide an additional adapter sequence using two '-a' parameters.

If the input sequences are in color space, the adapter must
also be provided in color space (using a string of digits 0123).

EXAMPLE

Assuming your sequencing data is available as a FASTQ file, use this
command line:
$ cutadapt -e ERROR-RATE -a ADAPTER-SEQUENCE input.fastq > output.fastq"""

from __future__ import print_function, division

__version__ = '0.9'

import sys
import re
import gzip
import time
from string import maketrans
from optparse import OptionParser, OptionGroup
from itertools import izip
from contextlib import closing
from collections import defaultdict

import align
import fasta

# Alignment flags used as the 'where' value of each adapter and passed on to
# align.globalalign by find_best_alignment. BACK is used for adapters given
# with -a, ANYWHERE for adapters given with -b (see main).
BACK = align.START_WITHIN_SEQ2 | align.STOP_WITHIN_SEQ2 | align.STOP_WITHIN_SEQ1
ANYWHERE = align.SEMIGLOBAL

# for double-encoding colorspace sequences: colors 0, 1, 2, 3 become
# A, C, G, T and '.' becomes N
DOUBLE_ENCODE_TRANS = maketrans('0123.', 'ACGTN')


class FormatError(Exception):
	"""
	Raised when an input file cannot be parsed, for example when the
	records of a .csfasta file and its .qual file do not match up.
	"""


class UnknownFileType(Exception):
	"""Raised by fastafiletype() when a file is recognized as neither FASTA nor FASTQ."""


class HelpfulOptionParser(OptionParser):
	"""OptionParser variant that shows the complete help text before any error message."""

	def error(self, msg):
		"""Print the full help to stderr, then exit with status 2 and report msg."""
		program = self.get_prog_name()
		self.print_help(sys.stderr)
		self.exit(2, "\n%s: error: %s\n" % (program, msg))


def xopen(filename, mode='r', is_closing=True):
	"""
	Replacement for the "open" function that can also open
	files that have been compressed with gzip.

	If the filename ends with .gz, the file is opened with gzip.open().
	If it doesn't, the regular open() is used. If the filename is '-',
	standard input is returned for reading modes (any mode containing 'r')
	and standard output for all other modes ('w', 'a', ...).
	Note that using a standard stream in a 'with' statement closes it
	when the block exits.

	is_closing -- whether to wrap the returned GzipFile with contextlib.closing
		to make it usable in 'with' statements. Has no effect for plain
		files or for '-'. (TODO look for a nicer solution)
	"""
	if filename == '-':
		# reading from '-' means stdin, writing to '-' means stdout
		return sys.stdin if 'r' in mode else sys.stdout
	if filename.endswith('.gz'):
		if is_closing:
			return closing(gzip.open(filename, mode))
		return gzip.open(filename, mode)
	return open(filename, mode)


def quality_trim_index(qualities, cutoff, base=33):
	"""
	Find the position at which to trim a low-quality end from a nucleotide sequence.

	This is a Python version of the BWA function 'bwa_trim_read':
	cutoff - quality is summed up starting at the 3' end and moving towards
	the 5' end; the read is cut at the position where this partial sum is
	largest. The scan stops as soon as the partial sum becomes negative.

	qualities -- string of ASCII-encoded qualities (chr(qual + base))
	cutoff -- quality cutoff
	base -- ASCII code of quality zero (default: 33)

	Return the index at which to cut, that is, qualities[:index] is kept.
	len(qualities) is returned if nothing needs to be trimmed.

	>>> quality_trim_index(''.join(chr(q + 33) for q in [30, 30, 5, 5]), 10)
	2
	"""
	s = 0
	max_qual = 0
	max_i = len(qualities)
	for i in reversed(range(len(qualities))):
		q = ord(qualities[i]) - base
		s += cutoff - q
		if s < 0:
			break
		if s > max_qual:
			max_qual = s
			max_i = i
	return max_i


def print_histogram(d):
	"""
	Print a tab-separated length/count table to standard output.

	d -- a dictionary mapping values to their respective frequency;
		rows are printed in ascending order of the values, followed
		by an empty line.
	"""
	print("length", "count", sep="\t")
	for length, count in sorted(d.items()):
		print(length, count, sep="\t")
	print()


class Statistics(object):
	"""Store statistics about reads and adapters"""

	def __init__(self, adapters):
		"""
		adapters -- list of (where, adapter_sequence) tuples. Two histograms
			(adapter length -> count) are kept per adapter: one for matches
			at the 5' end and one for matches at the 3' end or within the read.
		"""
		self.reads_changed = 0
		self.too_short = 0
		self.too_long = 0
		self.n = 0
		self._start_time = self._clock()
		self.time = None
		self.lengths_front = [defaultdict(int) for _ in adapters]
		self.lengths_back = [defaultdict(int) for _ in self.lengths_front]
		self.adapters = adapters

	@staticmethod
	def _clock():
		"""Return the current processor time in seconds."""
		try:
			return time.process_time()
		except AttributeError:
			# older Pythons have no process_time(); use clock() as before
			return time.clock()

	def stop_clock(self):
		"""Stop the timer that was automatically started when the class was instantiated."""
		self.time = self._clock() - self._start_time

	def _percent(self, count):
		"""Return count as a percentage of the processed reads (0 if there were none)."""
		return 100. * count / self.n if self.n else 0.

	def print_statistics(self, error_rate):
		"""
		Print a summary of the run and per-adapter length histograms to stdout.

		error_rate -- maximum error rate that was used (printed as a percentage)

		When no reads were processed, all percentages and the time per read
		are reported as 0.
		"""
		if self.time is None:
			self.stop_clock()
		time_per_read = 1000. * self.time / self.n if self.n else 0.
		print("cutadapt version", __version__)
		print("Command line parameters:", " ".join(sys.argv[1:]))
		print("Maximum error rate: %.2f%%" % (error_rate * 100.))
		print("   Processed reads:", self.n)
		print("     Trimmed reads:", self.reads_changed, "(%5.1f%%)" % self._percent(self.reads_changed))
		print("   Too short reads:", self.too_short, "(%5.1f%% of processed reads)" % self._percent(self.too_short))
		print("    Too long reads:", self.too_long, "(%5.1f%% of processed reads)" % self._percent(self.too_long))
		print("        Total time: %9.2f s" % self.time)
		print("     Time per read: %9.2f ms" % time_per_read)
		print()
		for index, (where, adapter) in enumerate(self.adapters):
			total_front = sum(self.lengths_front[index].values())
			total_back = sum(self.lengths_back[index].values())
			total = total_front + total_back
			# only ANYWHERE adapters can be trimmed from the 5' end
			assert where == ANYWHERE or (where == BACK and total_front == 0)

			print("=" * 3, "Adapter", index+1, "=" * 3)
			print()
			print("Adapter '%s'," % adapter, "length %d," % len(adapter), "was trimmed", total, "times.")
			if where == ANYWHERE:
				print(total_front, "times, it overlapped the 5' end of a read")
				print(total_back, "times, it overlapped the 3' end or was within the read")
				print()
				print("Histogram of adapter lengths (5')")
				print_histogram(self.lengths_front[index])
				print()
				print("Histogram of adapter lengths (3' or within)")
				print_histogram(self.lengths_back[index])
			else:
				print()
				print("Histogram of adapter lengths")
				print_histogram(self.lengths_back[index])


def quality_to_ascii(quality_line, base=33):
	"""
	Translate whitespace-separated integer qualities into ASCII characters.

	Every quality q is encoded as chr(q + base), where base is the ASCII
	code that stands for quality zero (33 and 64 are the usual choices).

	>>> quality_to_ascii("17 4 29 18")
	'2%>3'
	"""
	return ''.join(chr(int(token) + base) for token in quality_line.split())


def find_best_alignment(adapters, seq, max_error_rate, minimum_overlap):
	"""
	Find the best matching adapter.

	adapters -- List of (where, adapter_sequence) tuples. 'where' is BACK or
		ANYWHERE and is passed on to align.globalalign. For both, the
		adapter will also be found if it is in the middle of seq.
	seq -- The sequence to which each adapter will be aligned
	max_error_rate -- Maximum allowed error rate. The error rate is
		the number of errors in the alignment divided by the length
		of the part of the alignment that matches the adapter.
	minimum_overlap -- Alignments in which the matching part of the adapter
		is shorter than this are ignored. Empty (zero-length) alignments
		are always ignored.

	Return tuple (best_alignment, best_index), or (None, None) if no
	adapter matches.

	best_alignment is a tuple (r1, r2, astart, astop, rstart, rstop, errors)
	as returned by align.globalalign; for an exact match, r1 and r2 are None.
	best_index is the index of the best adapter into the adapters list.
	If several adapters have the same score, the first one wins.
	"""
	best_score = 0
	best_alignment = None
	best_index = None
	for index, (where, adapter) in enumerate(adapters):
		# try to find an exact match first
		pos = seq.find(adapter)
		if pos >= 0:
			alignment = (None, None, 0, len(adapter), pos, pos + len(adapter), 0)
		else:
			alignment = align.globalalign(adapter, seq, where)
		(_, _, astart, astop, _, _, errors) = alignment
		length = astop - astart
		# check length <= 0 first: it would cause a division by zero below
		if length <= 0 or length < minimum_overlap or errors / length > max_error_rate:
			continue

		# the length of the matching part of the adapter minus errors
		# determines which adapter fits best
		score = length - errors
		if score > best_score:
			best_alignment = alignment
			best_score = score
			best_index = index
	return (best_alignment, best_index)


def fastafiletype(fname):
	"""
	Determine the file type of fname from its first line that is not a
	'#' comment.

	Return the string 'FASTQ' if that line starts with '@' or 'FASTA' if it
	starts with '>'. Raise UnknownFileType if it starts with anything else,
	or if there is no such line at all (empty or comment-only file).
	"""
	with xopen(fname) as f:
		for line in f:
			if line.startswith('#'):
				continue
			if line.startswith('@'):
				return 'FASTQ'
			if line.startswith('>'):
				return 'FASTA'
			raise UnknownFileType("neither FASTQ nor FASTA")
	# reached only if the file had no non-comment line
	raise UnknownFileType("file is empty or contains only comment lines")


def write_read(desc, seq, qualities, outfile):
	"""
	Write one read record to outfile.

	The record is FASTQ ('@' header, sequence, '+', qualities) when
	qualities are given, and FASTA ('>' header, sequence) when qualities
	is None.
	"""
	if qualities is not None:
		record = '@%s\n%s\n+\n%s' % (desc, seq, qualities)
	else:
		record = '>%s\n%s' % (desc, seq)
	print(record, file=outfile)


def readsolid(seqfile, qualityfile):
	"""
	Read sequences from a .csfasta and a .qual file.
	Since this file format is used only by SOLiD, the data is assumed to be in color space.

	Yield (description, sequence, qualities) tuples; qualities are
	ASCII-encoded by quality_to_ascii. Each sequence still starts with the
	initial primer, so it is one character longer than its qualities.

	Raise FormatError if corresponding records have different descriptions,
	a sequence is empty, the number of quality values does not fit the
	sequence, or one file contains more records than the other.
	"""
	seq_iter = iter(fasta.readfasta(seqfile))
	quality_iter = iter(fasta.readfasta(qualityfile))
	while True:
		# Advance both files in lockstep by hand (instead of izip) so that a
		# file running out of records early is reported as an error rather
		# than silently cutting the output short.
		seq_record = next(seq_iter, None)
		quality_record = next(quality_iter, None)
		if seq_record is None and quality_record is None:
			return
		if seq_record is None:
			raise FormatError("Quality file contains more records than the FASTA file (next quality record: '%s')." % quality_record[0])
		if quality_record is None:
			raise FormatError("FASTA file contains more records than the quality file (no qualities for '%s')." % seq_record[0])
		rdesc, rseq = seq_record
		qdesc, qseq = quality_record
		if qdesc != rdesc:
			raise FormatError("Descriptions in FASTA and quality file don't match (%s and %s)." % (rdesc, qdesc))
		qualities = quality_to_ascii(qseq)
		if len(rseq) == 0:
			raise FormatError("When reading '%s', no sequence was found (at least the initial primer must appear)." % rdesc)
		if len(qualities) != len(rseq) - 1:
			raise FormatError("While reading '%s': expected to find %d quality values, but found %d." % (rdesc, len(rseq) - 1, len(qualities)))
		yield rdesc, rseq, qualities


def read_sequences(seqfilename, qualityfilename, colorspace):
	"""
	Generate (description, sequence, qualities) tuples from one of:

	* seqfilename in FASTQ format (qualityfilename must be None)
	* seqfilename in FASTA format (qualityfilename must be None)
	* seqfilename in .csfasta format plus qualityfilename in .qual format
	  (SOLiD color space; colorspace must be True)

	qualities is None when no quality information is available. Otherwise
	they are ASCII-encoded; for .qual input this is chr(quality + 33).

	Because this is a generator, the file type is detected and the
	ValueErrors for inconsistent arguments are raised only when iteration
	starts.
	"""
	ftype = fastafiletype(seqfilename)
	has_quality_file = qualityfilename is not None

	if has_quality_file and ftype == 'FASTQ':
		raise ValueError("If a FASTQ file is given, no quality file can be provided.")

	with xopen(seqfilename) as seqfile:
		if ftype == 'FASTQ':
			for record in fasta.readfastq(seqfile, colorspace=colorspace):
				desc, seq, qualities = record
				yield (desc, seq, qualities)
			return
		if ftype == 'FASTA' and not has_quality_file:
			for desc, seq in fasta.readfasta(seqfile):
				yield (desc, seq, None)
			return
		# remaining case: .csfasta sequences plus a separate .qual file
		if not colorspace:
			raise ValueError(".csfasta/.qual file found, please specify -c") # TODO enable automatically
		assert ftype == 'FASTA' and has_quality_file and colorspace
		with xopen(qualityfilename) as qualityfile:
			for triple in readsolid(seqfile, qualityfile):
				desc, seq, qualities = triple
				yield (desc, seq, qualities)


class ReadFilter(object):
	"""
	Decide whether a read is kept, based only on its length.

	Reads shorter than minimum_length or longer than maximum_length are
	rejected and counted in the statistics object; too-short reads are
	additionally written to too_short_outfile if one was given.
	"""

	def __init__(self, minimum_length, maximum_length, too_short_outfile, statistics):
		self.minimum_length, self.maximum_length = minimum_length, maximum_length
		self.too_short_outfile = too_short_outfile
		self.statistics = statistics

	def keep(self, desc, seq, qualities):
		"""Return True if the read passes both length limits, False otherwise."""
		read_length = len(seq)
		if read_length < self.minimum_length:
			self.statistics.too_short += 1
			if self.too_short_outfile is not None:
				write_read(desc, seq, qualities, self.too_short_outfile)
		elif read_length > self.maximum_length:
			self.statistics.too_long += 1
		else:
			return True
		return False


class AdapterCutter(object):
	"""Cut adapters from reads and collect statistics about what was cut."""

	def __init__(self, options, adapters):
		"""
		options -- parsed command-line options (see main)
		adapters -- list of (where, adapter_sequence) tuples, where 'where'
			is BACK or ANYWHERE
		"""
		self.options = options
		self.adapters = adapters
		self.stats = Statistics(adapters)
		# Regular expression used to replace "length=..." strings in sequence
		# descriptions. The tag is escaped so that regex metacharacters in it
		# (such as '.' or '+') are matched literally.
		if self.options.length_tag is not None:
			self.lengthregex = re.compile(r"\b" + re.escape(self.options.length_tag) + r"[0-9]*\b")
		else:
			self.lengthregex = None

	def cut(self, desc, seq, qualities):
		"""
		Cut adapters from a single read.

		desc -- description of the read
		seq -- sequence of the read
		qualities -- quality values of the read (None if not available)

		Return a tuple (desc, seq, qualities, trimmed) with the modified read.
		trimmed is True when any adapter was found and trimmed.
		"""
		self.stats.n += 1

		if __debug__:
			old_length = len(seq)

		# try (possibly more than once) to remove an adapter
		any_adapter_matches = False
		for _ in range(self.options.times):
			alignment, index = find_best_alignment(self.adapters, seq, self.options.error_rate, self.options.overlap)
			if alignment is None:
				# nothing found
				break

			(_, _, astart, astop, rstart, rstop, errors) = alignment
			length = astop - astart
			assert length > 0
			assert errors/length <= self.options.error_rate
			assert length - errors > 0

			any_adapter_matches = True
			where = self.adapters[index][0]

			if where == BACK or (astart == 0 and rstart > 0):
				assert where == ANYWHERE or astart == 0
				# The adapter is at the end of the read or within the read:
				# remove it and everything that follows.
				if rstop < len(seq):
					# The adapter is within the read
					if self.options.rest_file is not None:
						print(seq[rstop:], file=self.options.rest_file)
				self.stats.lengths_back[index][length] += 1

				if self.options.colorspace:
					# trim one more color if long enough
					rstart = max(0, rstart - 1)
				seq = seq[:rstart]
				if qualities is not None:
					qualities = qualities[:rstart]
					assert len(qualities) == len(seq)
			elif where == ANYWHERE:
				# The adapter overlaps the beginning of the read: remove it
				# and everything that precedes its end.
				assert rstart == 0
				self.stats.lengths_front[index][length] += 1
				# TODO What should we do in color space?
				seq = seq[rstop:]
				if qualities is not None:
					qualities = qualities[rstop:]
			else:
				assert False

		if __debug__:
			# if an adapter was found, then the read should now be shorter
			assert (not any_adapter_matches) or (len(seq) < old_length)
		if any_adapter_matches: # TODO move to filter class
			self.stats.reads_changed += 1

		# other modifications to the sequence or its description
		# change any length=XX that may appear in the read description (for 454)
		if self.lengthregex is not None and self.options.length_tag in desc:
			new_tag = self.options.length_tag + str(len(seq))
			# A function replacement inserts new_tag literally; a plain string
			# would have backslashes in the tag interpreted by re.sub.
			desc = self.lengthregex.sub(lambda match: new_tag, desc)

		if self.options.strip_f3 and desc.endswith('_F3'):
			desc = desc[:-3]
		desc = self.options.prefix + desc + self.options.suffix
		if self.options.double_encode:
			# convert color space sequence to double-encoded colorspace (using
			# characters ACGTN to represent colors)
			seq = seq.translate(DOUBLE_ENCODE_TRANS)

		if self.options.colorspace:
			assert not qualities or len(seq) == len(qualities)
		return (desc, seq, qualities, any_adapter_matches)


def _build_parser():
	"""Create the command-line parser with all cutadapt options."""
	parser = HelpfulOptionParser(usage=__doc__, version=__version__)

	group = OptionGroup(parser, "Options that influence how the adapters are found")
	group.add_option("-a", "--adapter", action="append", dest="adapters",
		help="Sequence of an adapter that was ligated to the 3' end. The adapter itself and anything that follows is trimmed. If multiple -a or -b options are given, only the best matching adapter is trimmed.")
	group.add_option("-b", "--anywhere", action="append", metavar="ADAPTER",
		help="Sequence of an adapter that was ligated to the 5' or 3' end. If the adapter is found within the read or overlapping the 3' end of the read, the behavior is the same as for the -a option. If the adapter overlaps the 5' end (beginning of the read), the initial portion of the read matching the adapter is trimmed, but anything that follows is kept. If multiple -a or -b options are given, only the best matching adapter is trimmed.")
	group.add_option("-e", "--error-rate", type=float, default=0.1,
		help="Maximum allowed error rate (no. of errors divided by the length of the matching region) (default: %default)")
	group.add_option("-n", "--times", type=int, metavar="COUNT", default=1,
		help="Try to remove adapters at most COUNT times. Useful when an adapter gets appended multiple times.")
	group.add_option("-O", "--overlap", type=int, metavar="LENGTH", default=3,
		help="Minimum overlap length. If the overlap between the adapter and the sequence is shorter than LENGTH, the read is not modified.")
	parser.add_option_group(group)

	group = OptionGroup(parser, "Options for filtering of processed reads")
	group.add_option("--discard", "--discard-trimmed", action='store_true', default=False,
		help="Discard reads that contain the adapter instead of trimming them. Also use -O in order to avoid throwing away too many randomly matching reads!")
	group.add_option("-m", "--minimum-length", type=int, default=0, metavar="LENGTH",
		help="Discard trimmed reads that are shorter than LENGTH. Reads that are too short even before adapter removal are also discarded. In colorspace, an initial primer is not counted.")
	group.add_option("-M", "--maximum-length", type=int, default=sys.maxint, metavar="LENGTH",
		help="Discard trimmed reads that are longer than LENGTH. "
			"Reads that are too long even before adapter removal "
			"are also discarded. In colorspace, an initial primer "
			"is not counted.")
	parser.add_option_group(group)

	group = OptionGroup(parser, "Options that influence what gets output to where")
	group.add_option("-o", "--output", default=None, metavar="FILE",
		help="The modified sequences get written to this file. The format is FASTQ if qualities are available, FASTA otherwise. (default: standard output)")
	group.add_option("-r", "--rest-file", default=None, metavar="FILE",
		help="When the adapter matches in the middle of a read, write the rest (after the adapter) into a file. Use - for standard output.")
	group.add_option("--too-short-output", default=None, metavar="FILE",
		help="Write reads that are too short (according to length specified by -m) to FILE. (default: discard reads)")
	group.add_option("--untrimmed-output", default=None, metavar="FILE",
		help="Write reads that do not contain the adapter to FILE, instead "
			"of writing them to the regular output file. (default: output "
			"to same file as trimmed)")
	parser.add_option_group(group)

	group = OptionGroup(parser, "Additional modifications to the reads")
	group.add_option("-q", "--quality-cutoff", type=int, default=0, metavar="CUTOFF",
		help="Trim low-quality ends from reads before adapter removal. "
			"The algorithm is the same as the one used by BWA "
			"(Subtract CUTOFF from all qualities; "
			"compute partial sums from all indices to the end of the "
			"sequence; cut sequence at the index at which the sum "
			"is minimal) (default: %default)")
	group.add_option("-x", "--prefix", default='',
		help="Add this prefix to read names")
	group.add_option("-y", "--suffix", default='',
		help="Add this suffix to read names")
	group.add_option("-c", "--colorspace", action='store_true', default=False,
		help="Colorspace mode: Also trim the color that is adjacent to the found adapter.")
	group.add_option("-d", "--double-encode", action='store_true', default=False,
		help="When in color space, double-encode colors (map 0,1,2,3,. to A,C,G,T,N).")
	group.add_option("-t", "--trim-primer", action='store_true', default=False,
		help="When in color space, trim primer base and the first color "
			"(which is the transition to the first nucleotide)")
	group.add_option("--strip-f3", action='store_true', default=False,
		help="For color space: Strip the _F3 suffix of read names")
	group.add_option("--maq", "--bwa", action='store_true', default=False,
		help="MAQ/BWA-compatible color space output. This enables -c, -d, -t, --strip-f3 and -y '/1'.")
	group.add_option("--length-tag", default=None, metavar="TAG",
		help="Search for TAG followed by a decimal number in the name of the read "
			"(description/comment field of the FASTA or FASTQ file). Replace the "
			"decimal number with the correct length of the trimmed read. "
			"For example, use --length-tag 'length=' to search for fields "
			"like 'length=123'.")
	parser.add_option_group(group)
	return parser


def _close_outputs(files):
	"""Close the given output files, but never the standard streams."""
	for outfile in files:
		if outfile is sys.stdout or outfile is sys.stdin or outfile is sys.stderr:
			continue
		outfile.close()


def main():
	"""
	Main function that evaluates command-line parameters and contains the
	main loop over all reads.

	Return the exit status: 0 on success, 1 if no adapter was given, quality
	trimming was requested for input without qualities, or an input file
	could not be parsed. Command-line usage errors exit via parser.error().
	"""
	parser = _build_parser()
	options, args = parser.parse_args()

	if len(args) == 0:
		parser.error("At least one parameter needed: name of a FASTA or FASTQ file.")
	elif len(args) > 2:
		parser.error("Too many parameters.")

	input_filename = args[0]
	quality_filename = args[1] if len(args) == 2 else None

	# Validate all options before any output file is created, so that a
	# usage error does not leave empty output files behind.
	if options.maq:
		options.colorspace = True
		options.double_encode = True
		options.trim_primer = True
		options.strip_f3 = True
		options.suffix = "/1"
	if options.trim_primer and not options.colorspace:
		parser.error("Trimming the primer makes only sense in color space.")
	if options.double_encode and not options.colorspace:
		parser.error("Double-encoding makes only sense in color space.")
	if options.anywhere and options.colorspace:
		parser.error("Using --anywhere with color space reads is currently not supported  (if you think this may be useful, contact the author).")
	if not (0 <= options.error_rate <= 1.):
		parser.error("The maximum error rate must be between 0 and 1.")

	adapters = []
	if options.adapters:
		adapters = [ (BACK, adapter) for adapter in options.adapters ]
	if options.anywhere:
		adapters += [ (ANYWHERE, adapter) for adapter in options.anywhere ]
	del options.adapters
	del options.anywhere

	if not adapters:
		print("You need to provide at least one adapter sequence.", file=sys.stderr)
		return 1

	opened_files = [] # output files opened below; closed in the finally clause

	def open_output(filename):
		"""Open filename for writing and remember it so that it gets closed."""
		outfile = xopen(filename, 'w', is_closing=False)
		opened_files.append(outfile)
		return outfile

	try:
		trimmed_outfile = sys.stdout # reads with adapters go here
		if options.output is not None:
			trimmed_outfile = open_output(options.output)
		untrimmed_outfile = trimmed_outfile # reads without adapters go here
		if options.untrimmed_output is not None:
			untrimmed_outfile = open_output(options.untrimmed_output)
		too_short_outfile = None # too short reads go here
		if options.too_short_output is not None:
			too_short_outfile = open_output(options.too_short_output)
		if options.rest_file is not None:
			options.rest_file = open_output(options.rest_file)

		cutter = AdapterCutter(options, adapters)
		readfilter = ReadFilter(options.minimum_length, options.maximum_length, too_short_outfile, cutter.stats) # TODO stats?
		try:
			reader = read_sequences(input_filename, quality_filename, colorspace=options.colorspace)
			for desc, seq, qualities in reader:
				# In colorspace, the first character is the last nucleotide of the primer base
				# and the second character encodes the transition from the primer base to the
				# first real base of the read.
				if options.trim_primer:
					seq = seq[2:]
					if qualities is not None:
						qualities = qualities[1:]
					initial = ''
				elif options.colorspace:
					# set the primer base aside; it is re-attached on output
					initial = seq[:1]
					seq = seq[1:]
				else:
					initial = ''

				if options.quality_cutoff > 0:
					if qualities is None:
						print("Quality trimming (-q) requires quality values, but read '%s' has none (FASTA input?)." % desc, file=sys.stderr)
						return 1
					index = quality_trim_index(qualities, options.quality_cutoff)
					qualities = qualities[:index]
					seq = seq[:index]

				desc, seq, qualities, trimmed = cutter.cut(desc, seq, qualities)
				if readfilter.keep(desc, seq, qualities):
					seq = initial + seq
					write_read(desc, seq, qualities, trimmed_outfile if trimmed else untrimmed_outfile)
		except FormatError as e:
			print(e, file=sys.stderr)
			return 1
	finally:
		# flush and close output files explicitly (gzip files need close()
		# to write their trailer)
		_close_outputs(opened_files)

	# statistics go to standard error so they do not mix with reads on stdout
	saved_stdout = sys.stdout
	sys.stdout = sys.stderr
	try:
		cutter.stats.print_statistics(options.error_rate)
	finally:
		sys.stdout = saved_stdout
	return 0


if __name__ == '__main__':
	# '--profile' as the very first argument runs main() under cProfile and
	# writes the profiling data to cutadapt.prof; otherwise main()'s return
	# value becomes the exit status.
	if sys.argv[1:2] != ['--profile']:
		sys.exit(main())
	else:
		del sys.argv[1]
		import cProfile as profile
		profile.run('main()', 'cutadapt.prof')
